import cvxpy as cp
import numpy as np
import argparse

# Item space: Theta = [0,1]^d.
# Supply: s(theta) is uniform on Theta.
def get_params_from_arguments():
    ''' Parse the integer command-line options --n, --t, --d and --sd.

    Returns the tuple (n, t, d, sd); options that are not given fall back to
    the defaults 50, 100, 10 and 1 respectively.
    '''
    option_defaults = (("n", 50), ("t", 100), ("d", 10), ("sd", 1))

    parser = argparse.ArgumentParser()
    for option, fallback in option_defaults:
        parser.add_argument(f"--{option}", type=int, default=fallback)

    parsed = parser.parse_args()
    return tuple(getattr(parsed, option) for option, _ in option_defaults)

def eval(cc, dd, ll, rr):
    ''' integrate (cc * theta + dd) on [ll, rr]

    NOTE: this name shadows the builtin eval inside this module.
    '''
    quadratic_part = 0.5 * cc * (rr**2 - ll**2)
    linear_part = dd * (rr - ll)
    return quadratic_part + linear_part

def square_integrate(cc, dd, ll, rr):
    ''' integrate (cc * theta + dd)^2 on [ll, rr]

    The square expands to cc^2 theta^2 + 2 cc dd theta + dd^2; each of the
    three monomials is integrated in closed form and the pieces are added.
    '''
    one_third = 1/3
    quadratic_piece = one_third * (cc**2) * (rr**3 - ll**3)
    cross_piece = cc * dd * (rr**2 - ll**2)
    constant_piece = (dd**2) * (rr - ll)
    return quadratic_piece + cross_piece + constant_piece

def create_one_dim_inf_dim_market(n, sd=2022, normalize=True):
    ''' Draw a reproducible one-dimensional market with n buyers.

    Buyer i has the linear valuation v[i](theta) = c[i] * theta + d[i].
    d[i] is uniform on [0, 2) and c[i] = 2 * (1 - d[i]), so that
    c[i]/2 + d[i] == 1, i.e. every valuation integrates to 1 on [0, 1].

    n: number of buyers.
    sd: seed used for the draw; the global NumPy RNG state is saved before
        seeding and restored afterwards, so callers' random streams are
        left untouched.
    normalize: only True is supported.

    Returns (B, c, d) where B are budgets normalized to sum to 1.
    Raises NotImplementedError if normalize is false.
    '''
    if not normalize:
        # reject the unsupported mode before the global RNG is reseeded
        raise NotImplementedError()

    curr_rand_state = np.random.get_state()
    try:
        np.random.seed(sd)
        B = np.random.uniform(size=n)
        B /= B.sum()

        # valuations: v[i](theta) = c[i] * theta + d[i]
        d = np.random.uniform(low=0.0, high=2.0, size=n)
        c = (1 - d) * 2
    finally:
        # always hand the caller's RNG stream back, even if sampling failed
        np.random.set_state(curr_rand_state)
    return B, c, d

def move_knife(i, ui, l, c, d):
    ''' given v[i], utility value and left endpoint, find right endpoint

    Solves for r in
        integral of (c[i] * theta + d[i]) over [l, r] == ui,
    i.e. c[i]/2 * r^2 + d[i] * r - (c[i]/2 * l^2 + d[i] * l + ui) == 0.

    i: buyer index into c and d.
    ui: utility the interval [l, r] must deliver to buyer i.
    l: left endpoint of the interval.
    c, d: slopes and intercepts of the linear valuations.

    Returns the right endpoint r.
    '''
    slope, intercept = c[i], d[i]
    if slope == 0:
        # constant valuation: the equation is linear, and the quadratic
        # formula below would evaluate 0/0
        return l + np.float64(ui) / intercept

    aa, bb = slope / 2, intercept
    cc = -(slope / 2 * l**2 + intercept * l + ui)
    # the "+ sqrt" root is the one where the valuation slope * r + intercept
    # equals +sqrt(discriminant), i.e. is nonnegative at r
    return (-bb + np.sqrt(bb**2 - 4 * aa * cc)) / (2 * aa)

def solve_one_dim_inf_dim_market(B, c, d, return_allocation=False):
    ''' Compute buyers' utilities for the one-dimensional market via a convex program.

    B, c, d: per-buyer arrays that are indexed with an index array below, so
        they need to be NumPy arrays. Presumably budgets and the slope /
        intercept of v[i](theta) = c[i] * theta + d[i], as produced by
        create_one_dim_inf_dim_market -- confirm against callers.
    return_allocation: if true, also return each buyer's position in the
        sorted order and the interval breakpoints.

    Returns u_eq (utilities in the original buyer order), or the tuple
    (u_eq, place_of_buyer, bpts) when return_allocation is true. bpts starts
    at 0, ends at 1, and has n+1 entries; the j-th buyer in decreasing order
    of d gets the interval [bpts[j], bpts[j+1]].

    Fails with AssertionError if the solver status is not exactly 'optimal'.

    NOTE(review): indexing z[0] and w[n-2] presumes n >= 2 -- confirm.
    NOTE(review): the constants 1 and -1 in the constraints look like they
    rely on every valuation integrating to 1 on [0, 1] (c[i]/2 + d[i] == 1)
    -- confirm that callers only pass normalized valuations.
    '''
    n = len(B)
    # buyers are processed in decreasing order of intercept d[i]
    sorted_indices = np.argsort(d)[::-1]
    # place_of_buyer[jj] = position of original buyer jj in the sorted order
    place_of_buyer = [-1] * n
    for ii, jj in enumerate(sorted_indices):
        place_of_buyer[jj] = ii

    # from here until "revert to original order", B, c, d are in sorted order
    B, c, d = B[sorted_indices], c[sorted_indices], d[sorted_indices]

    # build EG using CVXPY
    # (EG presumably stands for the Eisenberg-Gale convex program)
    u = cp.Variable(n, nonneg=True)
    z, w, s, t = cp.Variable(n-1), cp.Variable(n-1), cp.Variable(n-1), cp.Variable(n-1)
    # G[i] maps (s[i], t[i]) to (z[i], w[i]). When t[i] == s[i]**2 this gives
    #   z[i] =  d[i]*s[i]   + c[i]/2*s[i]**2    (integral of v[i]   on [0, s[i]])
    #   w[i] = -d[i+1]*s[i] - c[i+1]/2*s[i]**2  (minus integral of v[i+1] on [0, s[i]])
    G = [np.array([ [d[i], c[i]/2], [-d[i+1], -c[i+1]/2] ]) for i in range(n-1)]
    # budget-weighted sum of log utilities
    objective = cp.Maximize(cp.sum(B * cp.log(u)))
    # linear constraints for slack variables z, w
    constraints = [z>=0, z<=1, w<=0, w>=-1] 

    # linear constraints u together with z, w
    # with the reading of z, w above, z[i] + w[i-1] is the integral of v[i]
    # between consecutive breakpoints s[i-1] and s[i]
    constraints.append(u[0] <= z[0])
    for i in range(1, n-1):
        constraints.append(u[i] <= z[i] + w[i-1])
    # last buyer: the interval runs to the right end of [0, 1]
    constraints.append(u[n-1] <= 1+w[n-2])
    # second order cone constraints
    for i in range(n-1):
        constraints.append(z[i]+w[i] >= 0)
        constraints.append(G[i][0,0]*s[i]+G[i][0,1]*t[i] == z[i])
        constraints.append(G[i][1,0]*s[i]+G[i][1,1]*t[i] == w[i])
        # convex relaxation of t[i] == s[i]**2
        constraints.append(s[i]**2 <= t[i])

    prob = cp.Problem(objective, constraints)
    prob.solve()
    # any status other than exactly 'optimal' aborts here
    assert(prob.status=='optimal')
    # replace the CVXPY variables by their numeric solution values
    u, z, w, s, t = u.value, z.value, w.value, s.value, t.value
    # revert to original order
    u_eq, c, d = u[place_of_buyer], c[place_of_buyer], d[place_of_buyer]
    if return_allocation:
        # rebuild the breakpoints left to right: each buyer's right endpoint
        # is chosen so that the interval delivers exactly u_eq[i]
        bpts = [0]
        for j in range(n-1):
            i = sorted_indices[j] # get the j-th buyer i
            bpts.append(move_knife(i, u_eq[i], bpts[-1], c, d))
        # the last buyer's right endpoint is fixed at 1 rather than computed
        bpts.append(1)
        return u_eq, place_of_buyer, bpts
    return u_eq

def create_inf_dim_market(n, d=5, sd=2022, normalize=False):
    ''' Draw a reproducible market with n buyers and linear valuations on [0,1]^d.

    Buyer i has the valuation v[i](theta) = alpha[i] @ theta + const[i].
    alpha is standard normal; const[i] is 0.005 minus the sum of the negative
    entries of alpha[i], so v[i] is at least 0.005 everywhere on [0,1]^d.

    n: number of buyers.
    d: dimension of the item space.
    sd: seed used for the draw; the global NumPy RNG state is saved before
        seeding and restored afterwards, even if sampling raises.
    normalize: if true, rescale each valuation to integrate to 1 under the
        uniform supply on [0,1]^d.

    Returns (B, alpha, const) with B the budgets normalized to sum to 1,
    alpha of shape (n, d) and const of shape (n,).
    '''
    # supplies s = uniform on [0,1]^d
    curr_rand_state = np.random.get_state()
    try:
        np.random.seed(sd)
        B = np.random.uniform(size=n)
        B /= B.sum()

        # valuations
        alpha = np.random.normal(size=(n, d))
    finally:
        # no random draws happen after this point, so the caller's RNG
        # stream can be handed back right away (also on failure)
        np.random.set_state(curr_rand_state)

    const = np.array(
        [0.005 - alpha[i][alpha[i] < 0].sum() for i in range(n)]
    )

    if normalize:  # make them integrate to 1
        # the integral of alpha @ theta + const over the unit cube
        integrals = alpha.sum(axis=1) / 2 + const
        alpha /= integrals[:, None]
        const /= integrals

    return B, alpha, const

def sample_market_instance_one_dim(c, d, t):
    ''' Sample t items uniformly from [0, 1) and evaluate every buyer on them.

    Row i of the returned array holds c[i] * theta + d[i] for each of the
    t sampled items theta. Draws from the global NumPy RNG.
    '''
    sampled_items = np.random.uniform(size=t)
    valuation_rows = [
        slope * sampled_items + d[buyer] for buyer, slope in enumerate(c)
    ]
    return np.array(valuation_rows)

def sample_market_instance(alpha, const, t):
    ''' Sample t items uniformly from [0, 1)^d and evaluate every buyer on them.

    alpha has shape (n, d); entry (i, j) of the returned (n, t) array is
    alpha[i] @ theta_j + const[i] for the j-th sampled item theta_j.
    Draws from the global NumPy RNG.
    '''
    _, dim = alpha.shape
    sampled_items = np.random.uniform(size=(t, dim))

    # one matrix-vector product per buyer
    valuation_rows = [
        (weights @ sampled_items.T) + const[buyer]
        for buyer, weights in enumerate(alpha)
    ]
    return np.array(valuation_rows)

def check_eq(B, v_mat, x_eq, s, p_eq):
    ''' Report how far (x_eq, p_eq) is from a market equilibrium.

    Two residuals are computed and printed:
      * market clearance: 2-norm of (total allocation per item - s);
      * buyer optimality: 2-norm of the positive part of (u_best - u_eq),
        where u_best[i] is the utility buyer i gets from spending the whole
        budget B[i] on the item with the largest value-to-price ratio.

    Returns whether the larger of the two residuals is at most 5e-5.
    '''
    n, t = v_mat.shape
    buyers = np.arange(n)

    # utilities realized by the given allocation
    u_eq = (v_mat * x_eq).sum(axis=1)

    # how far total allocations are from the supplies
    mkt_clr_res = np.linalg.norm(x_eq.sum(axis=0) - s)

    # best bang-per-buck item of every buyer and the utility it would yield
    best_pick = np.argmax(v_mat / p_eq, axis=1)
    u_best = np.asarray(B) / np.asarray(p_eq)[best_pick] * v_mat[buyers, best_pick]
    shortfall = np.maximum(u_best - u_eq, 0)
    buyer_subopt_residual = np.linalg.norm(shortfall)

    report = [
        '===================================',
        f"n, t={n}, {t}",
        f"market clearance residual = {mkt_clr_res}",
        f"buyer optimality residual = {buyer_subopt_residual}",
        '===================================',
    ]
    print("\n".join(report))
    return max(mkt_clr_res, buyer_subopt_residual) <= 5e-5

def compute_b_from_x(x, v, B):
    ''' Turn allocations into bids: buyer i splits the budget B[i] across
    items in proportion to the utility v[i, j] * x[i, j] received from each.
    '''
    item_utilities = v * x
    buyer_totals = np.sum(item_utilities, 1)
    # work on the transpose so per-buyer vectors broadcast along the last axis
    spending_shares = item_utilities.T / buyer_totals
    return (spending_shares * B).T

def pr(v, B, max_iter=5000):
    ''' unit supply for each item

    Runs max_iter rounds of proportional response dynamics.

    v: (n, m) array of buyers' item values.
    B: length-n array of budgets.
    max_iter: number of rounds; must be at least 1.

    Returns (x, p): the (n, m) allocations and length-m prices of the last
    round.
    Raises ValueError if max_iter < 1 (no round would define the prices).
    '''
    if max_iter < 1:
        raise ValueError(f"max_iter must be at least 1, got {max_iter}")

    n, m = v.shape
    # start from every buyer holding a budget-proportional share of each item
    x = np.multiply((B/np.sum(B)), np.ones(shape=(m,n))).T
    b = v*x
    for _ in range(max_iter): # proportional response dynamics
        p = np.sum(b, axis=0) # compute prices
        x = b/p # compute allocations
        b = compute_b_from_x(x, v, B) # new bids
    return x, p

def compute_me_fin_dim(v, B, s=None, max_iter=1000):
    ''' Compute allocations and prices for a finite market with supplies s.

    v: (n, m) array of per-unit item values.
    B: budgets; None means equal budgets 1/n.
    s: item supplies; None means equal supplies 1/m.
    max_iter: number of rounds passed on to pr.

    The problem is reduced to unit supplies by scaling item j's values by
    s[j]; the result of pr is then scaled back.
    Returns (x * s, p / s) where (x, p) is what pr returns.
    '''
    num_buyers, num_items = v.shape
    budgets = np.ones(num_buyers) / num_buyers if B is None else B
    supplies = np.ones(num_items) / num_items if s is None else s

    # scale the item values
    unit_alloc, unit_prices = pr(v * supplies, budgets, max_iter)
    return unit_alloc * supplies, unit_prices / supplies

if __name__ == '__main__':
    # demo: sample a finite market from the infinite-dimensional one,
    # compute allocations/prices and print a price-consistency residual
    n = 50
    t = 500
    d = 5
    s = np.ones(shape=t)

    B, alpha, const = create_inf_dim_market(n, d)
    # sample_market_instance takes (alpha, const, t); the number of buyers
    # comes from alpha's shape
    v_mat = sample_market_instance(alpha, const, t)

    x_eq, p_eq = compute_me_fin_dim(v_mat, B, s, max_iter=10000)
    u_eq = (v_mat * x_eq).sum(axis=1)

    # sanity check: compare each price with the largest beta[i] * v[i, j]
    # over buyers, where beta[i] = B[i] / u_eq[i]
    beta = B / u_eq
    pp = (beta[:, None] * v_mat).max(axis=0)
    print(f"price residual = {np.linalg.norm(p_eq - pp)}")